import asyncio
import os
import re
import uuid
from functools import partial
from typing import Dict, List, Optional, Union

import httpx
from slack_bolt.adapter.fastapi.async_handler import AsyncSlackRequestHandler
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp

from chainlit.config import config
from chainlit.context import ChainlitContext, HTTPSession, context, context_var
from chainlit.data import get_data_layer
from chainlit.element import Element, ElementDict
from chainlit.emitter import BaseChainlitEmitter
from chainlit.logger import logger
from chainlit.message import Message, StepDict
from chainlit.types import Feedback
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_session


class SlackEmitter(BaseChainlitEmitter):
    """Emitter that forwards Chainlit elements and assistant messages to Slack.

    Holds the Bolt app, the target channel id, the ``say`` callable used to
    post messages and the (optional) thread timestamp that replies and file
    uploads are attached to.
    """

    def __init__(
        self,
        session: HTTPSession,
        app: AsyncApp,
        channel_id: str,
        say,
        thread_ts: Optional[str] = None,
    ):
        super().__init__(session)
        self.app, self.channel_id = app, channel_id
        self.say, self.thread_ts = say, thread_ts

    async def send_element(self, element_dict: ElementDict):
        """Upload an inline element to the Slack channel/thread.

        Elements whose ``display`` is not ``"inline"`` are ignored. The file
        content comes from the session's persisted files (looked up by
        ``chainlitKey``) or, failing that, is downloaded from the element's
        ``url``. Nothing is uploaded when no content could be obtained.
        """
        if element_dict.get("display") != "inline":
            return

        payload: Optional[Union[bytes, str]] = None
        stored = self.session.files.get(element_dict.get("chainlitKey") or "")

        if stored:
            # Persisted files are uploaded by path rather than by content.
            payload = str(stored["path"])
        else:
            remote_url = element_dict.get("url")
            if remote_url:
                async with httpx.AsyncClient() as http:
                    reply = await http.get(remote_url)
                    if reply.status_code == 200:
                        payload = reply.content

        if payload:
            await self.app.client.files_upload_v2(
                channel=self.channel_id,
                thread_ts=self.thread_ts,
                file=payload,
                title=element_dict.get("name"),
            )

    async def send_step(self, step_dict: StepDict):
        """Post a non-empty assistant message to Slack.

        Steps of any other type, or with an empty output, are skipped. When a
        data layer is configured, thumbs-down/thumbs-up buttons are appended;
        their value is the current run id if there is one, else the step id.
        """
        output = step_dict.get("output")
        if not output or step_dict.get("type") != "assistant_message":
            return

        feedback_enabled = get_data_layer()
        blocks: List[Dict] = [
            {"type": "section", "text": {"type": "mrkdwn", "text": output}}
        ]

        if feedback_enabled:
            run = context.current_run
            scorable_id = run.id if run else step_dict.get("id")
            buttons = [
                {
                    "action_id": action_id,
                    "type": "button",
                    "text": {"type": "plain_text", "emoji": True, "text": emoji},
                    "value": scorable_id,
                }
                for action_id, emoji in (
                    ("thumbdown", ":thumbsdown:"),
                    ("thumbup", ":thumbsup:"),
                )
            ]
            blocks.append({"type": "actions", "elements": buttons})

        await self.say(text=output, blocks=blocks, thread_ts=self.thread_ts)

    async def update_step(self, step_dict: StepDict):
        """Re-send an updated assistant message; other step types are ignored."""
        if step_dict["type"] == "assistant_message":
            await self.send_step(step_dict)


# Module-level Bolt app used by the handlers and helpers in this file.
# Both credentials are read from the environment when the module is imported;
# os.environ.get yields None for a variable that is not set.
slack_app = AsyncApp(
    token=os.environ.get("SLACK_BOT_TOKEN"),
    signing_secret=os.environ.get("SLACK_SIGNING_SECRET"),
)


async def start_socket_mode():
    """
    Start the Slack app in Socket Mode.

    The app-level token is taken from the SLACK_WEBSOCKET_TOKEN environment
    variable and handed to an AsyncSocketModeHandler wrapping ``slack_app``.
    """
    websocket_token = os.environ.get("SLACK_WEBSOCKET_TOKEN")
    await AsyncSocketModeHandler(slack_app, websocket_token).start_async()


def init_slack_context(
    session: HTTPSession,
    slack_channel_id: str,
    event,
    say,
    thread_ts: Optional[str] = None,
) -> ChainlitContext:
    """Create and install the Chainlit context for one Slack event.

    Builds a SlackEmitter for the channel/thread, stores the resulting
    ChainlitContext in ``context_var`` and seeds the user session with the raw
    Slack event plus a pre-bound history fetcher for this channel/thread.

    Returns:
        The newly created ChainlitContext.
    """
    # History fetcher already bound to this conversation's channel and thread.
    history_fetcher = partial(
        fetch_message_history, channel_id=slack_channel_id, thread_ts=thread_ts
    )

    slack_ctx = ChainlitContext(
        session=session,
        emitter=SlackEmitter(
            session=session,
            app=slack_app,
            channel_id=slack_channel_id,
            say=say,
            thread_ts=thread_ts,
        ),
    )
    context_var.set(slack_ctx)

    # user_session is written after context_var is set, as in the original order.
    user_session.set("slack_event", event)
    user_session.set("fetch_slack_message_history", history_fetcher)
    return slack_ctx


# Request handler (from slack_bolt's FastAPI adapter) wrapping the Bolt app.
slack_app_handler = AsyncSlackRequestHandler(slack_app)

# In-memory cache of users keyed by Slack user id; filled by get_user().
users_by_slack_id: Dict[str, Union[User, PersistedUser]] = {}

# Prefix prepended to the identifier of every user created from Slack.
USER_PREFIX = "slack_"


# Cached Slack user id of the bot itself; set lazily by get_bot_user_id().
bot_user_id: Optional[str] = None


async def get_bot_user_id() -> Optional[str]:
    """Return the bot's Slack user id, caching it in ``bot_user_id``.

    The id is fetched with ``auth_test`` on first use. Returns None when the
    call fails (the error is logged) or when the response is not ``ok``.
    """
    global bot_user_id

    if not bot_user_id:
        try:
            auth = await slack_app.client.auth_test()
            if not auth.get("ok"):
                return None
            bot_user_id = auth.get("user_id")
        except Exception as e:
            logger.error(f"Failed to get bot user ID: {e}")
            return None

    return bot_user_id


def clean_content(message: Optional[str]):
    """Remove Slack user mentions (``<@USERID>``) from a message.

    Args:
        message: Raw Slack message text. May be None or empty, e.g. when the
            event carries no ``text`` field.

    Returns:
        The text with every ``<@...>`` mention removed and surrounding
        whitespace stripped; an empty string when ``message`` is None or empty.
    """
    # re.sub raises TypeError on None, so guard before substituting.
    if not message:
        return ""
    return re.sub(r"<@[\w]+>", "", message).strip()


async def get_user(slack_user_id: str):
    """Return the Chainlit user for a Slack user id, creating it if needed.

    Cached users are returned directly. Otherwise the Slack profile is fetched,
    a User is built (identifier is USER_PREFIX plus the profile email, or the
    Slack id when there is no email) and cached. If a data layer is configured
    the user is also persisted and the persisted version replaces the cache
    entry; persistence errors are logged and do not propagate.
    """
    try:
        return users_by_slack_id[slack_user_id]
    except KeyError:
        pass

    info = await slack_app.client.users_info(user=slack_user_id)
    profile = info["user"]["profile"]

    identifier = USER_PREFIX + (profile.get("email") or slack_user_id)
    new_user = User(identifier=identifier, metadata=profile)

    # Cache the in-memory user first so it is available even if persisting fails.
    users_by_slack_id[slack_user_id] = new_user

    data_layer = get_data_layer()
    if data_layer:
        try:
            stored_user = await data_layer.create_user(new_user)
            if stored_user:
                users_by_slack_id[slack_user_id] = stored_user
        except Exception as e:
            logger.error(f"Error creating user: {e}")

    return users_by_slack_id[slack_user_id]


async def fetch_message_history(
    channel_id: str, thread_ts: Optional[str] = None, limit=30
):
    """Fetch recent messages from a channel or, if ``thread_ts`` is given, a thread.

    Returns:
        The ``messages`` list from the Slack response.

    Raises:
        Exception: when the Slack response is not ``ok``.
    """
    client = slack_app.client

    if thread_ts:
        result = await client.conversations_replies(
            channel=channel_id, ts=thread_ts, limit=limit
        )
    else:
        result = await client.conversations_history(channel=channel_id, limit=limit)

    if not result["ok"]:
        raise Exception(f"Failed to fetch messages: {result['error']}")

    return result["messages"]


async def download_slack_file(url, token):
    """Download a Slack-hosted file using bearer-token authorization.

    Returns:
        The response body as bytes on HTTP 200, otherwise None.
    """
    async with httpx.AsyncClient() as http:
        reply = await http.get(url, headers={"Authorization": f"Bearer {token}"})
        return reply.content if reply.status_code == 200 else None


async def download_slack_files(session: HTTPSession, files, token):
    """Download Slack file attachments and turn them into inline Elements.

    All files are downloaded concurrently from their ``url_private``. Each
    successful download is persisted on the session, and an Element is built
    for every persisted file that is present in ``session.files``.

    Returns:
        A list of Element objects (empty when nothing was downloaded).
    """
    payloads = await asyncio.gather(
        *(download_slack_file(item.get("url_private"), token) for item in files)
    )

    # Persist only the downloads that actually returned content.
    persisted_refs = []
    for slack_file, payload in zip(files, payloads):
        if not payload:
            continue
        ref = await session.persist_file(
            name=slack_file.get("name"),
            mime=slack_file.get("mimetype"),
            content=payload,
        )
        persisted_refs.append(ref)

    elements = []
    for ref in persisted_refs:
        key = ref["id"]
        if key not in session.files:
            continue
        stored = session.files[key]
        elements.append(
            Element.from_dict(
                {
                    "id": stored["id"],
                    "name": stored["name"],
                    "path": str(stored["path"]),
                    "chainlitKey": stored["id"],
                    "display": "inline",
                    "type": Element.infer_type_from_mime(stored["type"]),
                }
            )
        )

    return elements


async def add_reaction_if_enabled(event, emoji: str = "eyes"):
    """Add an emoji reaction to the event's message when the feature is on.

    Does nothing unless ``config.features.slack.reaction_on_message_received``
    is truthy. Failures to add the reaction are logged as warnings only.
    """
    if not config.features.slack.reaction_on_message_received:
        return

    try:
        await slack_app.client.reactions_add(
            channel=event["channel"], timestamp=event["ts"], name=emoji
        )
    except Exception as e:
        logger.warning(f"Failed to add reaction: {e}")


async def process_slack_message(
    event,
    say,
    thread_id: str,
    thread_name: Optional[str] = None,
    bind_thread_to_user=False,
    thread_ts: Optional[str] = None,
):
    """Run the Chainlit chat lifecycle for one incoming Slack message.

    Optionally reacts to the message, resolves the user, creates a fresh HTTP
    session and context, downloads attached files, then invokes the configured
    ``on_chat_start``, ``on_message`` and ``on_chat_end`` callbacks. When a
    data layer is configured the thread is updated afterwards. The session is
    always deleted at the end, even if a callback or download raises.

    Args:
        event: Slack event payload; must contain ``user`` and ``channel``.
        say: Callable used by the emitter to post replies.
        thread_id: Chainlit thread id for this conversation.
        thread_name: Name stored on the thread; defaults to the message content.
        bind_thread_to_user: When true and the user is a PersistedUser, the
            thread is stored with that user's id.
        thread_ts: Slack thread timestamp replies are posted to.
    """
    await add_reaction_if_enabled(event)

    user = await get_user(event["user"])

    channel_id = event["channel"]

    # "text" and "files" may be missing or None on the event; normalise them so
    # the message can still be built.
    text = event.get("text") or ""
    slack_files = event.get("files") or []

    session_id = str(uuid.uuid4())
    session = HTTPSession(
        id=session_id,
        thread_id=thread_id,
        user=user,
        client_type="slack",
    )

    ctx = init_slack_context(
        session=session,
        slack_channel_id=channel_id,
        event=event,
        say=say,
        thread_ts=thread_ts,
    )

    # Everything after the session exists runs under try/finally so the session
    # is cleaned up on failure too, matching handle_reaction_added.
    try:
        file_elements = await download_slack_files(
            session, slack_files, slack_app.client.token
        )

        if on_chat_start := config.code.on_chat_start:
            await on_chat_start()

        msg = Message(
            content=clean_content(text),
            elements=file_elements,
            type="user_message",
            author=user.metadata.get("real_name"),
        )

        if on_message := config.code.on_message:
            await on_message(msg)

        if on_chat_end := config.code.on_chat_end:
            await on_chat_end()

        if data_layer := get_data_layer():
            # Only a persisted user has an id the thread can be bound to.
            user_id = (
                user.id
                if bind_thread_to_user and isinstance(user, PersistedUser)
                else None
            )

            try:
                await data_layer.update_thread(
                    thread_id=thread_id,
                    name=thread_name or msg.content,
                    metadata=ctx.session.to_persistable(),
                    user_id=user_id,
                )
            except Exception as e:
                logger.error(f"Error updating thread: {e}")
    finally:
        await ctx.session.delete()


@slack_app.event("app_home_opened")
async def handle_app_home_opened(event, say):
    """Accept ``app_home_opened`` events and do nothing with them."""
    # NOTE(review): presumably registered only so the event has a listener — confirm.
    return None


@slack_app.event("app_mention")
async def handle_app_mentions(event, say):
    """Handle an ``app_mention`` event by processing it as a chat message.

    The reply thread is the event's ``thread_ts`` if present, else its own
    ``ts``; the Chainlit thread id is a UUIDv5 derived from that timestamp.
    """
    root_ts = event.get("thread_ts", event["ts"])
    chainlit_thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, root_ts))

    await process_slack_message(
        event=event,
        say=say,
        thread_id=chainlit_thread_id,
        thread_ts=root_ts,
    )


@slack_app.event("message")
async def handle_message(message, say):
    """Handle a ``message`` event by processing it as a chat message.

    Uses the message's ``thread_ts`` (or its own ``ts``) as the reply thread,
    derives the Chainlit thread id from it as a UUIDv5, and asks for the
    thread to be bound to the user.
    """
    root_ts = message.get("thread_ts", message["ts"])
    chainlit_thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, root_ts))

    await process_slack_message(
        message,
        say,
        thread_id=chainlit_thread_id,
        thread_ts=root_ts,
        bind_thread_to_user=True,
    )


@slack_app.event("reaction_added")
async def handle_reaction_added(event):
    """Handle a ``reaction_added`` event on one of the bot's messages.

    Reactions added by the bot itself, events without a channel, and reactions
    on messages whose ``user``/``bot_id`` do not match the bot's id are
    ignored. Otherwise a session and context are created, then the configured
    ``on_chat_start`` and ``on_slack_reaction_added`` callbacks are invoked;
    the session is deleted afterwards regardless of callback errors.
    """
    bot_id = await get_bot_user_id()
    if event.get("user") == bot_id:
        return

    reacted_item = event.get("item", {})
    channel_id = reacted_item.get("channel")
    thread_ts = reacted_item.get("ts")

    if not channel_id:
        logger.warning(
            "reaction_added event missing channel_id, skipping context setup"
        )
        return

    try:
        replies = await slack_app.client.conversations_replies(
            channel=channel_id, ts=thread_ts, limit=1
        )

        if not replies.get("ok"):
            raise Exception(
                f"Failed to fetch message: {replies.get('error', 'Unknown error')}"
            )

        reacted_message = replies.get("messages")[0]
        authors = (reacted_message.get("user"), reacted_message.get("bot_id"))

        # Only reactions on messages attributed to the bot are handled.
        if bot_id not in authors:
            return
    except Exception as e:
        # Any failure while looking up the message aborts handling.
        logger.warning(f"Failed to fetch message for reaction: {e}")
        return

    async def say(text: str = "", **kwargs):
        await slack_app.client.chat_postMessage(
            channel=channel_id, text=text, thread_ts=thread_ts, **kwargs
        )

    user = await get_user(event["user"])

    if thread_ts:
        thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts))
    else:
        thread_id = str(uuid.uuid4())

    session = HTTPSession(
        id=str(uuid.uuid4()),
        thread_id=thread_id,
        user=user,
        client_type="slack",
    )

    ctx = init_slack_context(
        session=session,
        slack_channel_id=channel_id,
        event=event,
        say=say,
        thread_ts=thread_ts,
    )

    try:
        on_chat_start = config.code.on_chat_start
        if on_chat_start:
            await on_chat_start()

        on_slack_reaction_added = config.code.on_slack_reaction_added
        if on_slack_reaction_added:
            await on_slack_reaction_added(event)
    finally:
        await ctx.session.delete()


async def _record_thumb_feedback(context, body, value: int, emoji: str):
    """Store a thumbs feedback score and replace the buttons with a receipt.

    Args:
        context: Bolt listener context; its ``client`` is used to update the
            message.
        body: Block-action payload. The clicked button's value is used as the
            id the feedback is for, and the message's ``thread_ts`` determines
            the Chainlit thread id.
        value: Feedback score (0 for thumbs down, 1 for thumbs up).
        emoji: Slack emoji code shown in the confirmation text.
    """
    step_id = body["actions"][0]["value"]
    thread_ts = body["message"]["thread_ts"]
    # Same UUIDv5 derivation the message handlers use for the thread id.
    thread_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, thread_ts))

    if data_layer := get_data_layer():
        feedback = Feedback(forId=step_id, value=value, threadId=thread_id)
        await data_layer.upsert_feedback(feedback)

    text = body["message"]["text"]
    blocks = body["message"]["blocks"]
    # Drop the button row so feedback cannot be submitted twice from this message.
    updated_blocks = [block for block in blocks if block["type"] != "actions"]
    updated_blocks.append(
        {
            "type": "section",
            "text": {"type": "mrkdwn", "text": f"{emoji} Feedback received."},
        }
    )
    await context.client.chat_update(
        channel=body["channel"]["id"],
        ts=body["container"]["message_ts"],
        text=text,
        blocks=updated_blocks,
    )


@slack_app.block_action("thumbdown")
async def thumb_down(ack, context, body):
    """Acknowledge a thumbs-down click and record it as feedback value 0."""
    await ack()
    await _record_thumb_feedback(context, body, value=0, emoji=":thumbsdown:")


@slack_app.block_action("thumbup")
async def thumb_up(ack, context, body):
    """Acknowledge a thumbs-up click and record it as feedback value 1."""
    await ack()
    await _record_thumb_feedback(context, body, value=1, emoji=":thumbsup:")
